﻿
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
using MySql.Data.MySqlClient;
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Data.SqlClient;
using System.Globalization;
using System.Text;

namespace CheryABP.EntityFrameworkCore
{
    /// <summary>
    /// A single raw SQL statement together with an optional parameter object.
    /// </summary>
    public class DbContextSqlQueryCommand
    {
        /// <summary>
        /// Creates a command whose SQL text takes no parameters;
        /// <see cref="Parameters"/> is left <c>null</c>.
        /// </summary>
        /// <param name="query">The SQL text to execute.</param>
        public DbContextSqlQueryCommand(string query)
            : this(query, null)
        {
        }

        /// <summary>
        /// Creates a command whose SQL text references parameters.
        /// </summary>
        /// <param name="query">The SQL text to execute.</param>
        /// <param name="params">Object carrying the parameter values; stored as-is in <see cref="Parameters"/>.</param>
        public DbContextSqlQueryCommand(string query, object @params)
        {
            Parameters = @params;
            Query = query;
        }

        /// <summary>
        /// The SQL text to execute.
        /// </summary>
        public string Query { get; set; }

        /// <summary>
        /// Object carrying the parameter values for <see cref="Query"/>;
        /// <c>null</c> when created through the query-only constructor.
        /// </summary>
        public object Parameters { get; set; }
    }

    /// <summary>
    /// Provider-specific variants of one SQL statement.
    /// Leave an entry <c>null</c> for providers that are not targeted.
    /// </summary>
    public class DbContextSqlQueryCommands
    {
        /// <summary>
        /// Statement used when the database is SQL Server.
        /// </summary>
        public DbContextSqlQueryCommand Sql { get; set; }

        /// <summary>
        /// Statement used when the database is MySQL.
        /// </summary>
        public DbContextSqlQueryCommand MySql { get; set; }

        /// <summary>
        /// Statement intended for the InMemory database.
        /// </summary>
        public DbContextSqlQueryCommand InMemory { get; set; }

        /// <summary>
        /// Statement intended for SQLite.
        /// </summary>
        public DbContextSqlQueryCommand Sqlite { get; set; }
    }

    /// <summary>
    /// Kinds of database a <c>DbContext</c> may be backed by.
    /// </summary>
    public enum DbContextType
    {
        /// <summary>In-memory database.</summary>
        InMemory = 0,

        /// <summary>Microsoft SQL Server.</summary>
        SqlServer = 1,

        /// <summary>MySQL.</summary>
        MySql = 2,

        /// <summary>SQLite.</summary>
        Sqlite = 3
    }

    /// <summary>
    /// Extension methods on <see cref="DbContext"/> that run raw SQL through the context's
    /// underlying <see cref="DbConnection"/>, choosing the SQL variant in
    /// <see cref="DbContextSqlQueryCommands"/> that matches the database provider in use.
    /// </summary>
    public static class DbContextExtensions
    {
        // Value of DatabaseFacade.ProviderName for the EF Core SQLite provider. Compared as a
        // string so this file needs no reference to the SQLite provider package.
        private const string SqliteProviderName = "Microsoft.EntityFrameworkCore.Sqlite";

        // Parameter-name prefix used for every provider (SQL Server, MySQL, SQLite all accept "@").
        private const string ParameterPrefix = "@";

        /// <summary>
        /// Adds one parameter per readable public property of <paramref name="params"/>,
        /// named <c>@PropertyName</c>.
        /// </summary>
        /// <param name="command">Command that receives the parameters.</param>
        /// <param name="params">Object whose properties supply the values; may be <c>null</c>.</param>
        private static void AddParameters(DbCommand command, object @params)
        {
            if (@params == null)
            {
                return;
            }

            foreach (var property in @params.GetType().GetProperties())
            {
                // Write-only properties and indexers cannot be read with GetValue(obj).
                if (!property.CanRead || property.GetIndexParameters().Length > 0)
                {
                    continue;
                }

                // CreateParameter yields the provider's own parameter type, so this works whichever
                // ADO.NET client library the EF provider is built on (no hard-coded SqlParameter/MySqlParameter).
                var parameter = command.CreateParameter();
                parameter.ParameterName = ParameterPrefix + property.Name;
                // A CLR null must be sent as DBNull.Value; SqlClient, for one, treats a null Value
                // as "parameter not supplied" and the statement fails.
                parameter.Value = property.GetValue(@params) ?? DBNull.Value;
                command.Parameters.Add(parameter);
            }
        }

        /// <summary>
        /// Picks the statement from <paramref name="commands"/> that matches the context's provider.
        /// </summary>
        /// <exception cref="NotImplementedException">
        /// No provider-specific statement matched but an InMemory statement was supplied
        /// (raw SQL for InMemory is not implemented).
        /// </exception>
        /// <exception cref="InvalidOperationException">No statement is configured for the current provider.</exception>
        private static DbContextSqlQueryCommand SelectCommand(DbContext context, DbContextSqlQueryCommands commands)
        {
            if (commands.MySql != null && context.Database.IsMySql())
            {
                return commands.MySql;
            }
            if (commands.Sql != null && context.Database.IsSqlServer())
            {
                return commands.Sql;
            }
            if (commands.Sqlite != null
                && string.Equals(context.Database.ProviderName, SqliteProviderName, StringComparison.Ordinal))
            {
                return commands.Sqlite;
            }
            if (commands.InMemory != null)
            {
                throw new NotImplementedException();
            }
            throw new InvalidOperationException(
                "No SQL command is configured for database provider '" + context.Database.ProviderName + "'.");
        }

        /// <summary>
        /// Builds the provider-specific command, runs <paramref name="run"/> on it and cleans up.
        /// The connection is opened (and afterwards closed) only if it was not already open, so a
        /// connection EF is already using (e.g. for an active transaction) is left as it was.
        /// </summary>
        private static TResult Execute<TResult>(
            DbContext context,
            DbContextSqlQueryCommands commands,
            Func<DbCommand, TResult> run)
        {
            if (context == null)
            {
                throw new ArgumentNullException(nameof(context));
            }
            if (commands == null)
            {
                throw new ArgumentNullException(nameof(commands));
            }

            // Resolve the statement before touching the connection so a configuration error cannot leak it open.
            var selected = SelectCommand(context, commands);
            var connection = context.Database.GetDbConnection();
            var openedHere = connection.State != ConnectionState.Open;
            if (openedHere)
            {
                connection.Open();
            }

            try
            {
                using (var command = connection.CreateCommand())
                {
                    command.CommandText = selected.Query;
                    // Enlist in the context's current transaction (if any); without this some
                    // providers reject a command issued on a connection with a pending transaction.
                    command.Transaction = context.Database.CurrentTransaction?.GetDbTransaction();
                    AddParameters(command, selected.Parameters);
                    return run(command);
                }
            }
            finally
            {
                if (openedHere)
                {
                    connection.Close();
                }
            }
        }

        /// <summary>
        /// Executes a SQL statement and returns the number of affected rows.
        /// </summary>
        /// <param name="context">EF context whose connection is used.</param>
        /// <param name="commands">Provider-specific statements.</param>
        /// <returns>The value returned by <see cref="DbCommand.ExecuteNonQuery"/>.</returns>
        public static int Exec(this DbContext context, DbContextSqlQueryCommands commands)
        {
            return Execute(context, commands, command => command.ExecuteNonQuery());
        }

        /// <summary>
        /// Runs a query and loads its first result set into a <see cref="DataTable"/>.
        /// </summary>
        /// <param name="context">EF context whose connection is used.</param>
        /// <param name="commands">Provider-specific statements.</param>
        /// <returns>The first result set.</returns>
        public static DataTable Query(this DbContext context, DbContextSqlQueryCommands commands)
        {
            return Execute(context, commands, command =>
            {
                using (var reader = command.ExecuteReader())
                {
                    var table = new DataTable();
                    table.Load(reader);
                    return table;
                }
            });
        }

        /// <summary>
        /// Runs a query and loads every result set it returns into a <see cref="DataSet"/>,
        /// one table per result set, in order.
        /// </summary>
        /// <param name="context">EF context whose connection is used.</param>
        /// <param name="commands">Provider-specific statements.</param>
        /// <returns>A data set holding all result sets.</returns>
        public static DataSet QuerySet(this DbContext context, DbContextSqlQueryCommands commands)
        {
            return Execute(context, commands, command =>
            {
                var set = new DataSet();
                using (var reader = command.ExecuteReader())
                {
                    // DataTable.Load consumes one result set, then advances to the next one or
                    // closes the reader when none remain, which ends this loop.
                    while (!reader.IsClosed)
                    {
                        var table = new DataTable();
                        table.Load(reader);
                        set.Tables.Add(table);
                    }
                }
                return set;
            });
        }

        /// <summary>
        /// Runs a query and returns its first row.
        /// </summary>
        /// <param name="context">EF context whose connection is used.</param>
        /// <param name="commands">Provider-specific statements.</param>
        /// <returns>The first row, or <c>null</c> when the query returns no rows.</returns>
        public static DataRow QueryOne(this DbContext context, DbContextSqlQueryCommands commands)
        {
            var table = Query(context, commands);
            return table.Rows.Count > 0 ? table.Rows[0] : null;
        }

        /// <summary>
        /// Runs a query and returns the first column of the first row.
        /// </summary>
        /// <param name="context">EF context whose connection is used.</param>
        /// <param name="commands">Provider-specific statements.</param>
        /// <returns>The value returned by <see cref="DbCommand.ExecuteScalar"/>.</returns>
        public static object QueryObject(this DbContext context, DbContextSqlQueryCommands commands)
        {
            return Execute(context, commands, command => command.ExecuteScalar());
        }

        /// <summary>
        /// Runs a query and returns the first column of the first row converted to <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">Requested result type.</typeparam>
        /// <param name="context">EF context whose connection is used.</param>
        /// <param name="commands">Provider-specific statements.</param>
        /// <returns>
        /// The scalar as <typeparamref name="T"/>. A value of a different convertible type (e.g. an
        /// <c>int</c> when <c>long</c> is requested) is converted; <c>null</c> or <see cref="DBNull"/>
        /// yields <c>default(T)</c>.
        /// </returns>
        /// <exception cref="InvalidCastException">The scalar cannot be converted to <typeparamref name="T"/>.</exception>
        public static T QueryObject<T>(this DbContext context, DbContextSqlQueryCommands commands)
        {
            var value = QueryObject(context, commands);
            if (value is T)
            {
                return (T)value;
            }
            if (value == null || value is DBNull)
            {
                return default(T);
            }

            // Providers differ in the CLR type they return (e.g. COUNT(*) as int vs long),
            // so convert instead of unboxing directly; Nullable<U> converts to U.
            var targetType = Nullable.GetUnderlyingType(typeof(T)) ?? typeof(T);
            return (T)Convert.ChangeType(value, targetType, CultureInfo.InvariantCulture);
        }
    }
}
